"""
Phase 3: Phillips Curve Flattening
Does aging flatten the Phillips curve? Test via Z×output_gap interactions.
Key hypothesis: Z_1_x_output_gap < 0 → aging flattens Phillips curve.
"""

import sys
from pathlib import Path
import numpy as np
import pandas as pd

# ── paths ────────────────────────────────────────────────────────────────────
PROJECT = Path("/mnt/c/demographics_capital_flows")
sys.path.insert(0, str(PROJECT / "multilateral" / "src"))
from model import PanelGLS

# Panel CSV loaded by main(); it must provide iso3, year, inflation and every
# regressor named in main().
DATA_PATH = PROJECT / "monetary" / "data" / "processed" / "monetary_panel.csv"
# Output directory for the markdown tables. It is created at import time
# (parents=True, exist_ok=True) so the writes in main() never hit a missing
# directory.
TABLE_DIR = PROJECT / "monetary" / "output" / "tables"
TABLE_DIR.mkdir(parents=True, exist_ok=True)

# ISO3 codes of the 38 OECD member countries; main() uses them to build the
# "(b) OECD" subsample of Table 3.
OECD_38 = [
    "AUS", "AUT", "BEL", "CAN", "CHL", "COL", "CRI", "CZE", "DNK", "EST",
    "FIN", "FRA", "DEU", "GRC", "HUN", "ISL", "IRL", "ISR", "ITA", "JPN",
    "KOR", "LVA", "LTU", "LUX", "MEX", "NLD", "NZL", "NOR", "POL", "PRT",
    "SVK", "SVN", "ESP", "SWE", "CHE", "TUR", "GBR", "USA",
]


# ── helper ───────────────────────────────────────────────────────────────────
def run_model(panel, dep_var, rhs_vars):
    """Estimate a PanelGLS regression on the complete-case subsample.

    First drops every row with a missing value in *dep_var*, any of
    *rhs_vars*, ``iso3`` or ``year``. Estimation is skipped (with a printed
    SKIP line) when fewer than 50 rows survive. Any exception raised while
    fitting is printed as an ERROR line and swallowed, so one failing
    specification does not abort the whole run.

    Returns:
        A dict with ``coefs``, ``se`` and ``pvals`` (each mapping regressor
        name -> value) plus ``r2``, ``nobs`` and ``ncountries``; or None when
        the model was skipped or failed.
    """
    needed = [dep_var] + rhs_vars + ["iso3", "year"]
    complete = panel[needed].dropna()
    n_complete = len(complete)
    if n_complete < 50:
        print(f"  SKIP {dep_var}: only {n_complete} obs after dropna")
        return None

    try:
        endog = complete[dep_var].values
        exog = complete[rhs_vars].values
        model = PanelGLS()
        model.fit(endog, exog, complete["iso3"].values, complete["year"].values)
        # Per-regressor statistics, keyed by the names in rhs_vars.
        per_var = {
            key: dict(zip(rhs_vars, getattr(model, attr)))
            for key, attr in (("coefs", "beta"), ("se", "se"), ("pvals", "pvalues"))
        }
        return {
            **per_var,
            "r2": model.r_squared,
            "nobs": model.n_obs,
            "ncountries": model.n_countries,
        }
    except Exception as err:
        print(f"  ERROR {dep_var}: {err}")
        return None


def star(p):
    """Return the significance marker for p-value *p*.

    "***" if p < 0.01, "**" if p < 0.05, "*" if p < 0.10, else "".
    A NaN p-value fails every comparison and therefore gets "".
    """
    for cutoff, marker in ((0.01, "***"), (0.05, "**"), (0.10, "*")):
        if p < cutoff:
            return marker
    return ""


def fmt_coef(res, var):
    """Render *var*'s estimate from a run_model result as 'coef*** (se)'.

    The coefficient and standard error use three decimals, and the
    significance stars come from star() on the p-value. Returns "" when the
    model failed (*res* is None) or *var* was not one of its regressors.
    """
    if res is not None and var in res["coefs"]:
        coef, se, pval = (res[key][var] for key in ("coefs", "se", "pvals"))
        return f"{coef:.3f}{star(pval)} ({se:.3f})"
    return ""


def build_markdown_table(title, row_vars, col_labels, results, footer_rows):
    """Assemble a pipe-format markdown table as one string.

    Layout: a '# title' heading, a header row ('Variable' plus col_labels),
    a divider, one fmt_coef row per variable in row_vars (one cell per entry
    of results), a second divider, then one row per (label, cells) pair in
    footer_rows.
    """
    def _row(label, cells):
        # Every line shares the same '| label | c1 | c2 |' layout.
        return f"| {label} | " + " | ".join(cells) + " |"

    divider = "|" + "---|" * (len(col_labels) + 1)
    out = [f"# {title}\n", _row("Variable", col_labels), divider]
    out.extend(_row(var, [fmt_coef(r, var) for r in results]) for var in row_vars)
    out.append(divider)
    out.extend(_row(label, vals) for label, vals in footer_rows)
    return "\n".join(out)


# ── main ─────────────────────────────────────────────────────────────────────
def _footer_rows(results):
    """Build the R² / N obs / N countries footer rows (blank cells for failed models)."""
    return [
        ("R²", [f"{r['r2']:.3f}" if r else "" for r in results]),
        ("N obs", [f"{r['nobs']:,}" if r else "" for r in results]),
        ("N countries", [f"{r['ncountries']}" if r else "" for r in results]),
    ]


def _table3(panel):
    """Estimate, print and save Table 3 (Phillips curve with demographic interactions).

    Columns (a)-(d) add Z_1_x_output_gap on the full, OECD, pre-GFC
    (year <= 2007) and post-GFC (year >= 2008) samples. Columns (e) and (f)
    swap in old_dep_x_output_gap and working_age_share_x_output_gap on the
    full sample.

    Returns:
        run_model results in column order (None where a model failed).
    """
    print("\n── Table 3: Phillips Curve with Demographic Interactions ──")

    oecd = panel[panel["iso3"].isin(OECD_38)].copy()
    pre_gfc = panel[panel["year"] <= 2007].copy()
    post_gfc = panel[panel["year"] >= 2008].copy()

    z_vars = ["Z_1", "Z_2", "Z_3"]
    controls_pc = ["rgdp_growth", "fiscal_bal_gdp", "kaopen", "nfa_gdp_lag", "inflation_lag"]
    interactions = ["Z_1_x_output_gap", "old_dep_x_output_gap",
                    "working_age_share_x_output_gap"]

    def with_interaction(ivar):
        # The specifications differ only in which demographic x output_gap term they add.
        return ["output_gap"] + z_vars + [ivar] + controls_pc

    rhs_base = with_interaction("Z_1_x_output_gap")
    rhs_old = with_interaction("old_dep_x_output_gap")
    rhs_was = with_interaction("working_age_share_x_output_gap")

    # Union of all row variables for display (order-preserving de-duplication).
    all_row_vars = list(dict.fromkeys(["output_gap"] + z_vars + interactions + controls_pc))

    specs = [
        ("(a) Full", panel, rhs_base),
        ("(b) OECD", oecd, rhs_base),
        ("(c) Pre-GFC", pre_gfc, rhs_base),
        ("(d) Post-GFC", post_gfc, rhs_base),
        ("(e) old_dep interact", panel, rhs_old),
        ("(f) WAS interact", panel, rhs_was),
    ]

    results_t3 = []
    col_labels_t3 = []
    for label, samp, rhs in specs:
        res = run_model(samp, "inflation", rhs)
        results_t3.append(res)
        col_labels_t3.append(label)
        if res is None:
            print(f"  {label}: no result")
            continue
        # Report whichever interaction term this column includes.
        for ivar in interactions:
            if ivar in res["coefs"]:
                c = res["coefs"][ivar]
                p = res["pvals"][ivar]
                print(f"  {label}: {ivar} = {c:.3f} (p={p:.3f})")

    md_t3 = build_markdown_table(
        "Table 3: Phillips Curve with Demographic Interactions",
        all_row_vars, col_labels_t3, results_t3, _footer_rows(results_t3),
    )
    print("\n" + md_t3)
    t3_path = TABLE_DIR / "phase3_table3_phillips_interactions.md"
    # Explicit UTF-8: the table contains "²", which locale encodings may reject.
    t3_path.write_text(md_t3, encoding="utf-8")
    print(f"\nSaved: {t3_path}")
    return results_t3


def _assign_z1_terciles(panel):
    """Return a per-row Series labelling each row's country "young"/"middle"/"old".

    Countries are ranked by their median Z_1 across all years and split at
    the 1/3 and 2/3 quantiles of those medians (<= first cut: young,
    <= second cut: middle, otherwise old). Countries whose Z_1 is entirely
    missing are labelled NaN and so belong to no tercile.
    """
    # dropna(): an all-missing country has a NaN median, and NaN comparisons
    # are False, so the lambda below would otherwise label it "old". Z_1 is
    # not a Table 4 regressor, so run_model's dropna would not remove those
    # rows from the old-tercile regression. quantile() already skips NaN, so
    # the cut points are unaffected.
    z1_median = panel.groupby("iso3")["Z_1"].median().dropna()
    t1_cut = z1_median.quantile(1 / 3)
    t2_cut = z1_median.quantile(2 / 3)
    tercile_map = z1_median.apply(
        lambda x: "young" if x <= t1_cut else ("middle" if x <= t2_cut else "old")
    )
    # Countries absent from tercile_map map to NaN.
    return panel["iso3"].map(tercile_map)


def _table4(panel):
    """Estimate, print and save Table 4 (activity slope by Z_1 tercile).

    Returns:
        run_model results for the young, middle and old terciles, in that order.
    """
    print("\n── Table 4: Phillips Curve Slope by Z₁ Tercile ──")

    terciles = _assign_z1_terciles(panel)

    # Use rgdp_growth consistently across all terciles for comparability.
    # Output_gap is OECD-only and concentrates in the old tercile, making
    # cross-tercile slope comparisons invalid if measures differ.
    controls_tercile = ["fiscal_bal_gdp", "kaopen", "nfa_gdp_lag", "inflation_lag"]
    rhs_tercile = ["rgdp_growth"] + controls_tercile

    tercile_labels = ["Young (low Z₁)", "Middle", "Old (high Z₁)"]
    tercile_keys = ["young", "middle", "old"]

    results_t4 = []
    for label, tkey in zip(tercile_labels, tercile_keys):
        sub = panel[terciles == tkey]
        nc = sub["iso3"].nunique()
        res = run_model(sub, "inflation", rhs_tercile)
        results_t4.append(res)
        if res:
            ac = res["coefs"]["rgdp_growth"]
            ap = res["pvals"]["rgdp_growth"]
            print(f"  {label} ({nc} countries): rgdp_growth = {ac:.3f} "
                  f"(p={ap:.3f}), N={res['nobs']}")
        else:
            print(f"  {label} ({nc} countries): no result")

    md_t4 = build_markdown_table(
        "Table 4: Phillips Curve Slope by Z₁ Tercile",
        rhs_tercile, tercile_labels, results_t4, _footer_rows(results_t4),
    )
    md_t4 += "\n\n*Note: rgdp_growth used as the activity proxy across all terciles to ensure comparability. Output_gap is available only for OECD countries, which concentrate in the old tercile.*"
    print("\n" + md_t4)
    t4_path = TABLE_DIR / "phase3_table4_phillips_terciles.md"
    # Explicit UTF-8: the table contains "²" and "₁", which locale encodings may reject.
    t4_path.write_text(md_t4, encoding="utf-8")
    print(f"\nSaved: {t4_path}")
    return results_t4


def _activity_slope(res):
    """Return (name, coefficient) of the activity regressor in *res*, or None.

    Prefers output_gap over rgdp_growth when both are present.
    """
    for avar in ("output_gap", "rgdp_growth"):
        if avar in res["coefs"]:
            return avar, res["coefs"][avar]
    return None


def _print_key_finding(results_t3, results_t4):
    """Print the headline Z_1 x output_gap test and the young-vs-old slope comparison."""
    print("\n── Key Test: Does aging flatten the Phillips curve? ──")
    res_a = results_t3[0]  # Full sample with Z_1_x_output_gap
    if res_a and "Z_1_x_output_gap" in res_a["coefs"]:
        c = res_a["coefs"]["Z_1_x_output_gap"]
        p = res_a["pvals"]["Z_1_x_output_gap"]
        direction = "FLATTENS" if c < 0 else "STEEPENS"
        sig = "significant" if p < 0.10 else "NOT significant"
        print(f"  Z_1 x output_gap = {c:.3f} (p={p:.3f})")
        print(f"  → Aging {direction} the Phillips curve ({sig})")
    else:
        print("  Could not estimate key interaction.")

    young_res, old_res = results_t4[0], results_t4[2]
    if not (young_res and old_res):
        return
    young, old = _activity_slope(young_res), _activity_slope(old_res)
    if young is None or old is None:
        # Previously this case raised NameError on unbound loop variables.
        print("\n  Tercile comparison: no activity regressor to compare.")
        return
    young_var, young_slope = young
    old_var, old_slope = old
    print("\n  Tercile comparison:")
    print(f"    Young tercile {young_var} slope: {young_slope:.3f}")
    print(f"    Old tercile {old_var} slope:   {old_slope:.3f}")
    # Flattening = smaller-magnitude inflation response to activity among older countries.
    if abs(old_slope) < abs(young_slope):
        print("    → Consistent with flattening (|old| < |young|)")
    else:
        print("    → NOT consistent with flattening (|old| >= |young|)")


def main():
    """Run Phase 3 end to end.

    Loads the panel from DATA_PATH, estimates Table 3 (demographic x
    output_gap interactions) and Table 4 (activity slope by Z_1 tercile),
    writes both as markdown files under TABLE_DIR, and prints a summary of
    the flattening tests.
    """
    print("=" * 70)
    print("PHASE 3: PHILLIPS CURVE FLATTENING")
    print("=" * 70)

    panel = pd.read_csv(DATA_PATH)
    print(f"Panel: {len(panel):,} obs, {panel['iso3'].nunique()} countries")

    results_t3 = _table3(panel)
    results_t4 = _table4(panel)
    _print_key_finding(results_t3, results_t4)

    print("\nPhase 3 complete.")


if __name__ == "__main__":
    main()
